# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Init for base architecture engine monitor register. """

import time
import os
import stat
from typing import Optional, Union, Iterable
import numpy as np

import luojianet as ms
from luojianet import save_checkpoint
from luojianet.train.callback import Callback

# NOTE: `Rel` and the parameter `Validator` come from the local check_param module.
from check_param import Rel,Validator as validator
__all__ = ["LossMonitor", "ValAccMonitor"]


class LossMonitor(Callback):
    """
    Loss monitor callback for classification training.

    Args:
        lr_init (Union[float, Iterable], optional): A fixed learning rate or a
            per-step learning-rate schedule, used only for logging. Default: None.
        per_print_times (int): Interval, in steps, between log prints. Default: 1.

    Examples:
    from mindvision.engine.callback import LossMonitor
    lr = [0.01, 0.008, 0.006, 0.005, 0.002]
    monitor = LossMonitor(lr_init=lr, per_print_times=100)
    """

    def __init__(self,
                 lr_init: Optional[Union[float, Iterable]] = None,
                 per_print_times: int = 1):
        super().__init__()
        # Keep the schedule so step_end can report the current learning rate.
        self.per_print_times = per_print_times
        self.lr_init = lr_init
        # Global step index of the most recent printed log line.
        self.last_print_time = 0

    # pylint: disable=unused-argument





def epoch_begin(self, run_context):
    """
    Reset the loss buffer and start the epoch timer.

    Args:
        run_context (RunContext): Context of the process running.
    """
    # A fresh list per epoch keeps the reported average epoch-local.
    self.epoch_time = time.time()
    self.losses = []




def epoch_end(self, run_context):
    """
    Log epoch timing and the average loss once the epoch finishes.

    Args:
        run_context (RunContext): Context of the process running.
    """
    cb_params = run_context.original_args()
    elapsed_ms = (time.time() - self.epoch_time) * 1000
    # Average the wall-clock time over the number of batches in the epoch.
    step_ms = elapsed_ms / cb_params.batch_num
    print(f"Epoch time: {elapsed_ms:5.3f} ms, "
          f"per step time: {step_ms:5.3f} ms, "
          f"avg loss: {np.mean(self.losses):5.3f}", flush=True)


# pylint: disable=unused-argument


def step_begin(self, run_context):
    """
    Start the per-step timer.

    Args:
        run_context (RunContext): Context of the process running.
    """
    # step_end reads this to compute the per-step wall-clock time.
    self.step_time = time.time()




def step_end(self, run_context):
    """
    Print training info at the end of a step.

    Args:
        run_context (RunContext): Context of the process running.

    Raises:
        ValueError: If the loss becomes NaN or infinite (training diverged).
    """
    callback_params = run_context.original_args()
    step_mseconds = (time.time() - self.step_time) * 1000
    loss = callback_params.net_outputs

    # The network may return (loss, ...); unwrap the leading loss tensor.
    if isinstance(loss, (tuple, list)):
        # BUGFIX: was `np.ndarry`, which raised AttributeError on this path.
        if isinstance(loss[0], ms.Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
            loss = loss[0]

    if isinstance(loss, ms.Tensor) and isinstance(loss.asnumpy(), np.ndarray):
        loss = np.mean(loss.asnumpy())

    self.losses.append(loss)
    cur_step_in_epoch = (callback_params.cur_step_num - 1) % callback_params.batch_num + 1

    # Boundary check: stop training as soon as the loss diverges.
    if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
        raise ValueError("Invalid loss, terminate training.")

    def print_info():
        # Index into a per-step schedule when a list was supplied.
        lr_output = self.lr_init[callback_params.cur_step_num - 1] if isinstance(self.lr_init,
                                                                                 list) else self.lr_init
        # BUGFIX: the default lr_init=None used to crash the `:5.5f` format.
        lr_text = "" if lr_output is None else f"lr:{lr_output:5.5f}"
        print(f"Epoch:[{(callback_params.cur_epoch_num - 1):3d}/{callback_params.epoch_num:3d}], "
              f"step:[{cur_step_in_epoch:5d}/{callback_params.batch_num:5d}], "
              f"loss:[{loss:5.3f}/{np.mean(self.losses):5.3f}], "
              f"time:{step_mseconds:5.3f} ms, "
              f"{lr_text}", flush=True)

    if (callback_params.cur_step_num - self.last_print_time) >= self.per_print_times:
        self.last_print_time = callback_params.cur_step_num
        print_info()





class ValAccMonitor(Callback):
    """
    Track train loss and validation accuracy after each epoch, keeping on disk
    the checkpoint with the highest validation accuracy seen so far.

    Args:
        model (ms.Model): The model to monitor.
        dataset_val (ms.dataset): Validation dataset for the model.
        num_epochs (int): Total number of training epochs.
        interval (int): Validate and print every this many epochs. Default: 1.
        eval_start_epoch (int): First epoch at which to validate. Default: 1.
        save_best_ckpt (bool): Whether to save the best-performing checkpoint. Default: True.
        ckpt_directory (str): Directory for checkpoint files. Default: './'.
        best_ckpt_name (str): File name of the best checkpoint. Default: 'best.ckpt'.
        metric_name (str): Name of the metric used for evaluation. Default: 'Accuracy'.
        dataset_sink_mode (bool): Whether to use dataset sinking mode. Default: True.

    Raises:
        ValueError: If `interval` is not more than 1.

    Examples:
    import luojianet as ms
    import luojianet.nn as nn
    import luojianet.dataset as ds
    from mindvision.classification.models import lenet
    from mindvision.classification.dataset import Mnist

    net = lenet()
    opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.9)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True,reduction='mean')
    model = ms.Model(net, loss,opt,metrics={"Accuracy":nn.Accuracy()})
    dataset_val = Mnist("./mnist", split="test", batch_size=32, resize=32, download=True)
    dataset_val = dataset_val.run()
    monitor = ValAccMonitor(model, dataset_val, num_epochs=10)
    """

    def __init__(self,
                 model: ms.Model,
                 dataset_val: ms.dataset,
                 num_epochs: int,
                 interval: int = 1,
                 eval_start_epoch: int = 1,
                 save_best_ckpt: bool = True,
                 ckpt_directory: str = "./",
                 best_ckpt_name: str = "best.ckpt",
                 metric_name: str = "Accuracy",
                 dataset_sink_mode: bool = True):
        super().__init__()
        # `interval` must be a positive integer; validator raises otherwise.
        self.interval = validator.check_int(interval, 1, Rel.GE, "interval")
        self.model = model
        self.dataset_val = dataset_val
        self.dataset_sink_mode = dataset_sink_mode
        self.num_epochs = num_epochs
        self.eval_start_epoch = eval_start_epoch
        self.save_best_ckpt = save_best_ckpt
        self.metric_name = metric_name
        # Best validation result observed so far; updated in epoch_end.
        self.best_res = 0

        # Ensure the target directory exists before composing the path.
        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)
        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)





def apply_eval(self):
    """Evaluate the model on the validation set and return the tracked metric."""
    metrics = self.model.eval(self.dataset_val, dataset_sink_mode=self.dataset_sink_mode)
    return metrics[self.metric_name]





def epoch_end(self, run_context):
    """
    After an epoch, print train loss and validation accuracy and keep the
    checkpoint with the highest validation accuracy so far.

    Args:
        run_context (RunContext): Context of the process running.
    """
    cb_params = run_context.original_args()
    cur_epoch = cb_params.cur_epoch_num

    # Guard clause: only validate from eval_start_epoch on, every `interval` epochs.
    if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
        return

    # Validation result
    res = self.apply_eval()

    print("-" * 20)
    print(f"Epoch: [{cur_epoch: 3d} / {self.num_epochs: 3d}], "
          f"Train Loss: [{cb_params.net_outputs.asnumpy() :5.3f}], "
          f"{self.metric_name}: {res: 5.3f}")

    # Keep only the best checkpoint on disk.
    if res >= self.best_res:
        self.best_res = res
        if self.save_best_ckpt:
            if os.path.exists(self.best_ckpt_path):
                # Clear read-only protection before deleting the stale file.
                os.chmod(self.best_ckpt_path, stat.S_IWRITE)
                os.remove(self.best_ckpt_path)
            save_checkpoint(cb_params.train_network, self.best_ckpt_path)


# pylint: disable=unused-argument



def end(self, run_context):
    """
    Report the best validation result once network training completes.

    Args:
        run_context (RunContext): Context of the process running.
    """
    banner = "=" * 80
    print(banner)
    print(f"End of validation the best {self.metric_name} is: {self.best_res: 5.3f}, "
          f"save the best ckpt file in {self.best_ckpt_path}", flush=True)
